# `svg_heatmap.heatmap` can be used as a drop-in replacement for `seaborn.heatmap`,
# with the exception of a few missing features.
import numpy as np; np.random.seed(0)
import seaborn as sns; sns.set()
from matplotlib import pyplot as plt
from svg_heatmap import heatmap
from ipywidgets import HTML
import sys
from io import BytesIO
import binascii
def compare_plots(data, svg_kwargs=None, sns_kwargs=None, **kwargs):
    """Render ``data`` with svg_heatmap and seaborn and stack the results.

    The same ``data`` and shared ``kwargs`` are passed to both
    ``svg_heatmap.heatmap`` and ``seaborn.heatmap``; ``svg_kwargs`` and
    ``sns_kwargs`` carry options that only one of the two understands.
    The seaborn figure is exported twice (inline SVG and base64 PNG) so the
    payload size of all three renderings can be compared.

    Parameters
    ----------
    data : array-like
        Data passed unchanged to both heatmap functions.
    svg_kwargs : dict, optional
        Extra keyword arguments for ``svg_heatmap.heatmap`` only.
    sns_kwargs : dict, optional
        Extra keyword arguments for ``seaborn.heatmap`` only.
    **kwargs
        Keyword arguments passed to both functions.

    Returns
    -------
    ipywidgets.HTML
        Widget showing the svg_heatmap output, the seaborn SVG and the
        seaborn PNG, each headed by its encoded size in kB.
    """
    # None defaults avoid sharing one mutable dict across calls.
    svg_kwargs = {} if svg_kwargs is None else svg_kwargs
    sns_kwargs = {} if sns_kwargs is None else sns_kwargs

    svg_plot = heatmap(data, **svg_kwargs, **kwargs)

    fig = plt.figure()
    try:
        sns.heatmap(data, **sns_kwargs, **kwargs)
        plt.tight_layout()
        with BytesIO() as buf:
            fig.canvas.print_png(buf)
            png_data = binascii.b2a_base64(buf.getvalue()).decode()
        with BytesIO() as buf:
            fig.savefig(buf, format='svg')
            sns_svg_plot = buf.getvalue().decode()
    finally:
        # Close this specific figure, even on error, so repeated calls
        # do not accumulate open matplotlib figures.
        plt.close(fig)

    png_html = '<img src="data:image/png;base64,{}">'
    sns_png_plot = png_html.format(png_data)

    # Size of the markup actually sent to the browser. sys.getsizeof would
    # report the Python object's memory footprint (object header plus
    # 1/2/4 bytes per character), not the payload size.
    svg_size, sns_png_size, sns_svg_size = [
        '{:.1f}kB'.format(len(plot.encode('utf-8')) / 1024)
        for plot in (svg_plot, sns_png_plot, sns_svg_plot)
    ]
    return HTML('svg {}<br>'.format(svg_size) + svg_plot
                + '<br>sns svg {}<br>'.format(sns_svg_size) + sns_svg_plot
                + '<br>sns png {}<br>'.format(sns_png_size) + sns_png_plot)
# ndarray data
compare_plots(np.random.rand(10, 12), cmap='viridis')
# DataFrame data: the seaborn "flights" dataset pivoted to months (rows) x years (columns).
flights = sns.load_dataset("flights")
# Keyword arguments: pandas >= 2.0 made DataFrame.pivot's parameters keyword-only.
flights = flights.pivot(index="month", columns="year", values="passengers")
compare_plots(flights, cmap='magma')
compare_plots(flights, cmap='magma', cbar=False)
# Outliers: a small patch of cells far above the rest, rendered with log color scaling.
data_w_outliers = np.random.rand(10, 12)
# Row 2, columns 3-5 are lifted well above the original maximum.
data_w_outliers[2:3, 3:6] += 5 * data_w_outliers.max()

from matplotlib.colors import LogNorm

# svg_heatmap is given log_scaling=True; seaborn gets a LogNorm spanning the
# data range after the outliers were added.
compare_plots(
    data_w_outliers,
    cmap='magma',
    svg_kwargs={'log_scaling': True},
    sns_kwargs={'norm': LogNorm(vmin=data_w_outliers.min(),
                                vmax=data_w_outliers.max())},
)
# Colorbars (cbars) are drawn with pyplot and embedded in the root SVG, either as a
# base64-encoded PNG or as inline SVG markup. The SVG variant is significantly larger,
# however, because its text is rendered as <path> elements rather than <text> elements.
# Compare the size of a heatmap whose colorbar is embedded as SVG vs. as a base64 PNG.
svg_cbar_plot = heatmap(np.random.rand(10, 12), svg_cbar=True)
png_cbar_plot = heatmap(np.random.rand(10, 12), svg_cbar=False)
# Encoded markup length; sys.getsizeof would measure the str object's memory
# footprint (object header plus 1/2/4 bytes per character) instead.
svg_cbar_size, png_cbar_size = [
    '{:.1f}kB'.format(len(plot.encode('utf-8')) / 1024)
    for plot in (svg_cbar_plot, png_cbar_plot)
]
HTML('svg cbar {}<br>'.format(svg_cbar_size) + svg_cbar_plot +
     '<br>png cbar {}<br>'.format(png_cbar_size) + png_cbar_plot)